{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-27T09:44:13.362955Z",
     "start_time": "2019-11-27T09:44:07.902764Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/Users/eagle/anaconda3/envs/dl_py3/lib/python3.6/site-packages/konlpy/tag/_okt.py:16: UserWarning: \"Twitter\" has changed to \"Okt\" since KoNLPy v0.4.5.\n",
      "  warn('\"Twitter\" has changed to \"Okt\" since KoNLPy v0.4.5.')\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "import json\n",
    "import pickle\n",
    "import torch\n",
    "from gluonnlp.data import SentencepieceTokenizer\n",
    "from model.net import KobertCRF\n",
    "from data_utils.utils import Config\n",
    "from data_utils.vocab_tokenizer import Tokenizer\n",
    "from data_utils.pad_sequence import keras_pad_fn\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-27T09:44:13.406972Z",
     "start_time": "2019-11-27T09:44:13.370406Z"
    }
   },
   "outputs": [],
   "source": [
    "class DecoderFromNamedEntitySequence():\n",
    "    def __init__(self, tokenizer, index_to_ner):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.index_to_ner = index_to_ner\n",
    "\n",
    "    def __call__(self, list_of_input_ids, list_of_pred_ids):\n",
    "        input_token = self.tokenizer.decode_token_ids(list_of_input_ids)[0]\n",
    "        pred_ner_tag = [self.index_to_ner[pred_id] for pred_id in list_of_pred_ids[0]]\n",
    "\n",
    "        # ----------------------------- parsing list_of_ner_word ----------------------------- #\n",
    "        list_of_ner_word = []\n",
    "        entity_word, entity_tag, prev_entity_tag = \"\", \"\", \"\"\n",
    "        for i, pred_ner_tag_str in enumerate(pred_ner_tag):\n",
    "            if \"B-\" in pred_ner_tag_str:\n",
    "                entity_tag = pred_ner_tag_str[-3:]\n",
    "\n",
    "                if prev_entity_tag != entity_tag and prev_entity_tag != \"\":\n",
    "                    list_of_ner_word.append({\"word\": entity_word.replace(\"▁\", \" \"), \"tag\": prev_entity_tag, \"prob\": None})\n",
    "\n",
    "                entity_word = input_token[i]\n",
    "                prev_entity_tag = entity_tag\n",
    "            elif \"I-\"+entity_tag in pred_ner_tag_str:\n",
    "                entity_word += input_token[i]\n",
    "            else:\n",
    "                if entity_word != \"\" and entity_tag != \"\":\n",
    "                    list_of_ner_word.append({\"word\":entity_word.replace(\"▁\", \" \"), \"tag\":entity_tag, \"prob\":None})\n",
    "                entity_word, entity_tag, prev_entity_tag = \"\", \"\", \"\"\n",
    "\n",
    "\n",
    "        # ----------------------------- parsing decoding_ner_sentence ----------------------------- #\n",
    "        decoding_ner_sentence = \"\"\n",
    "        is_prev_entity = False\n",
    "        prev_entity_tag = \"\"\n",
    "        is_there_B_before_I = False\n",
    "\n",
    "        for i, (token_str, pred_ner_tag_str) in enumerate(zip(input_token, pred_ner_tag)):\n",
    "            if i == 0 or i == len(pred_ner_tag)-1: # remove [CLS], [SEP]\n",
    "                continue\n",
    "            token_str = token_str.replace('▁', ' ')  # '▁' 토큰을 띄어쓰기로 교체\n",
    "\n",
    "            if 'B-' in pred_ner_tag_str:\n",
    "                if is_prev_entity is True:\n",
    "                    decoding_ner_sentence += ':' + prev_entity_tag+ '>'\n",
    "\n",
    "                if token_str[0] == ' ':\n",
    "                    token_str = list(token_str)\n",
    "                    token_str[0] = ' <'\n",
    "                    token_str = ''.join(token_str)\n",
    "                    decoding_ner_sentence += token_str\n",
    "                else:\n",
    "                    decoding_ner_sentence += '<' + token_str\n",
    "                is_prev_entity = True\n",
    "                prev_entity_tag = pred_ner_tag_str[-3:] # 첫번째 예측을 기준으로 하겠음\n",
    "                is_there_B_before_I = True\n",
    "\n",
    "            elif 'I-' in pred_ner_tag_str:\n",
    "                decoding_ner_sentence += token_str\n",
    "\n",
    "                if is_there_B_before_I is True: # I가 나오기전에 B가 있어야하도록 체크\n",
    "                    is_prev_entity = True\n",
    "            else:\n",
    "                if is_prev_entity is True:\n",
    "                    decoding_ner_sentence += ':' + prev_entity_tag+ '>' + token_str\n",
    "                    is_prev_entity = False\n",
    "                    is_there_B_before_I = False\n",
    "                else:\n",
    "                    decoding_ner_sentence += token_str\n",
    "\n",
    "        return list_of_ner_word, decoding_ner_sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-27T09:44:13.441346Z",
     "start_time": "2019-11-27T09:44:13.414934Z"
    }
   },
   "outputs": [],
   "source": [
    "def main():\n",
    "    model_dir = Path('./experiments/base_model_with_crf')\n",
    "    model_config = Config(json_path=model_dir / 'config.json')\n",
    "\n",
    "    # load vocab & tokenizer\n",
    "    tok_path = \"./ptr_lm_model/tokenizer_78b3253a26.model\"\n",
    "    ptr_tokenizer = SentencepieceTokenizer(tok_path)\n",
    "\n",
    "    with open(model_dir / \"vocab.pkl\", 'rb') as f:\n",
    "        vocab = pickle.load(f)\n",
    "    tokenizer = Tokenizer(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=keras_pad_fn, maxlen=model_config.maxlen)\n",
    "\n",
    "    # load ner_to_index.json\n",
    "    with open(model_dir / \"ner_to_index.json\", 'rb') as f:\n",
    "        ner_to_index = json.load(f)\n",
    "        index_to_ner = {v: k for k, v in ner_to_index.items()}\n",
    "\n",
    "    # model\n",
    "    model = KobertCRF(config=model_config, num_classes=len(ner_to_index), vocab=vocab)\n",
    "\n",
    "    # load\n",
    "    model_dict = model.state_dict()\n",
    "    checkpoint = torch.load(\"./experiments/base_model_with_crf/best-epoch-16-step-1500-acc-0.993.bin\", map_location=torch.device('cpu'))\n",
    "    # checkpoint = torch.load(\"./experiments/base_model_with_crf_val/best-epoch-12-step-1000-acc-0.960.bin\", map_location=torch.device('cpu'))\n",
    "    convert_keys = {}\n",
    "    for k, v in checkpoint['model_state_dict'].items():\n",
    "        new_key_name = k.replace(\"module.\", '')\n",
    "        if new_key_name not in model_dict:\n",
    "            print(\"{} is not int model_dict\".format(new_key_name))\n",
    "            continue\n",
    "        convert_keys[new_key_name] = v\n",
    "\n",
    "    model.load_state_dict(convert_keys)\n",
    "    model.eval()\n",
    "    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "    model.to(device)\n",
    "    decoder_from_res = DecoderFromNamedEntitySequence(tokenizer=tokenizer, index_to_ner=index_to_ner)\n",
    "\n",
    "    while(True):\n",
    "        input_text = input('input> ')\n",
    "        if input_text == 'end':\n",
    "            break\n",
    "        \n",
    "        list_of_input_ids = tokenizer.list_of_string_to_list_of_cls_sep_token_ids([input_text])\n",
    "        x_input = torch.tensor(list_of_input_ids).long()\n",
    "        list_of_pred_ids = model(x_input)\n",
    "\n",
    "        list_of_ner_word, decoding_ner_sentence = decoder_from_res(list_of_input_ids=list_of_input_ids, list_of_pred_ids=list_of_pred_ids)\n",
    "        print(\"output>\", decoding_ner_sentence)\n",
    "        print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-27T09:45:47.655369Z",
     "start_time": "2019-11-27T09:44:13.443848Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input> 지난달 28일 수원에 살고 있는 윤주성 연구원은 코엑스(서울 삼성역)에서 개최되는 DEVIEW 2019 Day1에 참석했다. LaRva팀의 '엄~청 큰 언어 모델 공장 가동기!' 세션을 들으며 언어모델을 학습시킬때 multi-GPU, TPU 모두 써보고 싶다는 생각을 했다.\n",
      "output>  지난달 <28일:DAT> <수원:LOC>에 살고 있는 <윤주성:PER> 연구원은 <코엑스:LOC>(<서울:LOC> <삼성역:LOC>)에서 개최되는 <DEVIEW 2019 Day1:POH>에 참석했다. <LaRva팀:ORG>의 '엄~청 큰 언어 모델 공장 가동기!' 세션을 들으며 언어모델을 학습시킬때 multi-GPU, TPU 모두 써보고 싶다는 생각을 했다.\n",
      "\n",
      "input> 문재인 대통령은 28일 서울 코엑스에서 열린 ‘데뷰 (Deview) 2019’ 행사에 참석해 젊은 개발자들을 격려하면서 우리 정부의 인공지능 기본구상을 내놓았다.  출처 : 미디어오늘 (http://www.mediatoday.co.kr)\n",
      "output>  <문재인:PER> 대통령은 <28일:DAT> <서울 코엑스:LOC>에서 열린 ‘<데뷰 (Deview) 2019:POH>’ 행사에 참석해 젊은 개발자들을 격려하면서 우리 정부의 인공지능 기본구상을 내놓았다. 출처 : <미디어오늘:POH> (<http://www.mediatoday.co.kr:POH>)\n",
      "\n",
      "input> SKTBrain에서 KoBERT 모델을 공개해준 덕분에 BERT-CRF 기반 개체명인식기를 쉽게 개발할 수 있었다.\n",
      "output>  <SKTBrain:ORG>에서 <KoBERT:POH> 모델을 공개해준 덕분에 <BERT-CRF:POH> 기반 개체명인식기를 쉽게 개발할 수 있었다.\n",
      "\n",
      "input> 터미네이터: 다크 페이트 (Terminator: Dark Fate)는 2019년 개봉한 미국의 SF, 액션영화이다. 1991년 영화 터미네이터 2: 심판의 날 이후 28년 만에 제임스 카메론이 제작자로서 시리즈에 복귀한 작품이다. 린다 해밀턴이 사라 코너 역으로 돌아오면서 아널드 슈워제네거와 함께 주연을 맡았다.\n",
      "output>  <터미네이터::POH> <다크 페이트:POH> (<Terminator: Dark Fate:POH>)는 <2019년:DAT> 개봉한 미국의 SF, 액션영화이다. <1991년:DAT> 영화 <터미네이터 2: 심판의 날:POH> 이후 <28년:NOH> 만에 <제임스 카메론:PER>이 제작자로서 시리즈에 복귀한 작품이다. <린다 해밀턴:PER>이 <사라 코너:PER> 역으로 돌아오면서 <아널드 슈워제네거:PER>와 함께 주연을 맡았다.\n",
      "\n",
      "input> [뉴스토마토 김희경 기자] 영화 '터미네이터: 다크 페이트'(감독 팀 밀러)가 박스오피스 1위는 물론 전체 예매율 1위를 차지했다. 시리즈 최고 오프닝 스코어 경신과 함께 겹경사다.\n",
      "output>  <[뉴스토마토:ORG> <김희경:PER> 기자] 영화 '<터미네이터: 다크 페이트:POH>'(감독 <팀 밀러:PER>)가 <박스오피스:ORG> <1위는:NOH> 물론 전체 예매율 <1위를:NOH> 차지했다. 시리즈 최고 오프닝 스코어 경신과 함께 겹경사다.\n",
      "\n",
      "input> 전 세계 최고의 기대작 <어벤져스> 시리즈의 압도적 대미를 장식할 <어벤져스: 엔드게임>이 지난 4월 14일(일)과 15일(월) 양일간 진행된 대한민국 내한 행사를 성공적으로 마무리 지었다. <어벤져스: 엔드게임>의 주역 로버트 다우니 주니어, 제레미 레너, 브리 라슨, 안소니 루소&조 루소 감독, 트린 트랜 프로듀서, 케빈 파이기 마블 스튜디오 대표까지 방문하여 특별한 대한민국 사랑을 뽐냈다.\n",
      "output>  전 세계 최고의 기대작 <<어벤져스:POH>> 시리즈의 압도적 대미를 장식할 <<어벤져스: 엔드게임:POH>>이 지난 <4월 14일:DAT>(<일:DAT>)과 <15일:DAT>(<월:DAT>) <양일간:NOH> 진행된 <대한민국:LOC> 내한 행사를 성공적으로 마무리 지었다. <<어벤져스: 엔드게임:POH>>의 주역 <로버트 다우니 주니어:PER>, <제레미 레너:PER>, <브리 라슨:PER>, <안소니 루소:PER>&<조 루소:PER> 감독, <트린 트랜:PER> 프로듀서, <케빈 파이기:PER> <마블 스튜디오:ORG> 대표까지 방문하여 특별한 <대한민국:ORG> 사랑을 뽐냈다.\n",
      "\n",
      "input> 영화 '겨울왕국2'의 이현민 애니메이션 슈퍼바이저가 SBS '나이트라인'에 출연해 다양한 이야기를 전했다. 숨겨진 과거의 비밀과 새로운 운명을 찾기 위해 모험을 떠나는 '엘사'와 '안나'의 이야기를 담은 작품이다.\n",
      "output>  영화 '<겨울왕국2:POH>'의 <이현민:PER> 애니메이션 슈퍼바이저가 <SBS:POH> '<나이트라인:POH>'에 출연해 다양한 이야기를 전했다. 숨겨진 과거의 비밀과 새로운 운명을 찾기 위해 모험을 떠나는 '<엘사:PER>'와 '<안나:PER>'의 이야기를 담은 작품이다.\n",
      "\n",
      "input> 네이버(NAVER (167,000원▲ 3,000 1.83%)) 금융 계열사인 ‘네이버파이낸셜’이 1일 출범했다. 네이버파이낸셜은 기존 결제·송금 서비스를 하던 ‘네이버페이’를 분사해 설립한 회사다. 네이버파이낸셜은 외연을 확장해 ‘네이버 통장’을 선보이고, 이어 주식, 보험, 예·적금, 신용카드 서비스도 출시한다는 계획이다.\n",
      "output>  <네이버:ORG>(<NAVER:ORG> <(167,000원:MNY>▲ <3,000:MNY> 1.83%)) 금융 계열사인 ‘<네이버파이낸셜:ORG>’이 <1:DAT>일 출범했다. <네이버파이낸셜:ORG>은 기존 결제·송금 서비스를 하던 ‘<네이버페이:ORG>’를 분사해 설립한 회사다. <네이버파이낸셜:ORG>은 외연을 확장해 ‘<네이버:ORG> 통장’을 선보이고, 이어 주식, 보험, 예·적금, 신용카드 서비스도 출시한다는 계획이다.\n",
      "\n",
      "input> 이동륜 KB증권 연구원은 \"카카오페이를 통한 거래대금이 올해 상반기 약 22조원으로 지난해 연간 거래액을 초과 달성하는 등 안정적으로 성장하고 있고, 청구서·멤버십·간편보험 등 생활밀착형 서비스로 사업 확장이 진행되고 있다\"며 \"내년 중에는 분기 기준 흑자 전환이 예상된다\"고 말했다.\n",
      "output>  <이동륜:PER> <KB증권:ORG> 연구원은 \"<카카오페이:POH>를 통한 거래대금이 올해 상반기 약 <22조원:MNY>으로 지난해 연간 거래액을 초과 달성하는 등 안정적으로 성장하고 있고, 청구서·멤버십·간편보험 등 생활밀착형 서비스로 사업 확장이 진행되고 있다\"며 \"내년 중에는 분기 기준 흑자 전환이 예상된다\"고 말했다.\n",
      "\n",
      "input> 엔씨소프트가 자연어처리(NPL) 개발 스타트업인 '스캐터랩'에 전략적 투자를 단행한다. 국내 대표적인 온라인 게임회사 엔씨소프트의 투자가 게임 업계에 어떠한 변화를 가져올 지 주목된다. 12일 업계에 따르면 스캐터랩은 벤처캐피탈 등을 대상으로 시리즈 B 규모의 투자 유치 작업을 진행하고 있다.\n",
      "output>  <엔씨소프트:ORG>가 자연어처리(NPL) 개발 스타트업인 '<스캐터랩:ORG>'에 전략적 투자를 단행한다. 국내 대표적인 온라인 게임회사 <엔씨소프트:ORG>의 투자가 게임 업계에 어떠한 변화를 가져올 지 주목된다. <12일:DAT> 업계에 따르면 <스캐터랩:ORG>은 벤처캐피탈 등을 대상으로 <시리즈:NOH> B 규모의 투자 유치 작업을 진행하고 있다.\n",
      "\n",
      "input> ‘모든 단점은 장점이 될수 있다'  (Lionel Andres Messi)\n",
      "output>  ‘모든 단점은 장점이 될수 있다' (<Lionel Andres Messi:POH>)\n",
      "\n",
      "input> end\n"
     ]
    }
   ],
   "source": [
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl_py3",
   "language": "python",
   "name": "dl_py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}